{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "db4dc272-88fe-47ad-98fd-b94d4f840dca",
      "metadata": {
        "id": "db4dc272-88fe-47ad-98fd-b94d4f840dca"
      },
      "source": [
        "# PEFT with DNA Language Models"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "d381f473-0d37-4b5b-ae9e-d2b32bab7c04",
      "metadata": {
        "id": "d381f473-0d37-4b5b-ae9e-d2b32bab7c04"
      },
      "source": [
        "This notebook demonstrates how to use parameter-efficient fine-tuning (PEFT) techniques from the PEFT library to fine-tune a DNA language model (DNA-LM). The fine-tuned DNA-LM is then applied to a task from the Nucleotide Transformer downstream-task benchmark. PEFT techniques make it practical to adapt large pre-trained models to specific tasks with limited computational resources."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "23f460c3-d7e5-437f-a5e9-d029cd225bf8",
      "metadata": {
        "id": "23f460c3-d7e5-437f-a5e9-d029cd225bf8"
      },
      "source": [
        "### 1. Import relevant libraries"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "29a35f95-738a-4f5e-88ce-dc5f8f9be5dc",
      "metadata": {
        "id": "29a35f95-738a-4f5e-88ce-dc5f8f9be5dc"
      },
      "source": [
        "We'll start by importing the required libraries, including the PEFT library and other dependencies."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "id": "0a40abdf-ca1c-436f-a2af-603cd67a45a4",
      "metadata": {
        "id": "0a40abdf-ca1c-436f-a2af-603cd67a45a4"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/opt/homebrew/anaconda3/envs/peft/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
            "  from .autonotebook import tqdm as notebook_tqdm\n"
          ]
        }
      ],
      "source": [
        "import torch\n",
        "import transformers\n",
        "import peft\n",
        "import tqdm\n",
        "import numpy as np"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "a445f8be-545d-4085-a5f9-c64983655224",
      "metadata": {
        "id": "a445f8be-545d-4085-a5f9-c64983655224"
      },
      "source": [
        "### 2. Load models\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "63782b55-1c38-4e44-b003-e57daa813bed",
      "metadata": {
        "id": "63782b55-1c38-4e44-b003-e57daa813bed"
      },
      "source": [
        "We'll load a pre-trained DNA language model, \"SpeciesLM\", that serves as the base for fine-tuning. This is done using the `transformers` library from Hugging Face.\n",
        "\n",
        "The tokenizer and the model come from the paper \"Species-aware DNA language models capture regulatory elements and their evolution\" ([Paper Link](https://www.biorxiv.org/content/10.1101/2023.01.26.525670v2), [Code Link](https://github.com/gagneurlab/SpeciesLM)). The authors introduce a species-aware DNA language model trained on more than 800 species spanning over 500 million years of evolution."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "id": "dac961f4-c450-4124-923e-f4ba9bbd5e07",
      "metadata": {
        "id": "dac961f4-c450-4124-923e-f4ba9bbd5e07"
      },
      "outputs": [],
      "source": [
        "from transformers import AutoTokenizer, AutoModelForMaskedLM"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "id": "e73fae58-03e9-4acc-b0fc-9bc810c7d366",
      "metadata": {
        "id": "e73fae58-03e9-4acc-b0fc-9bc810c7d366"
      },
      "outputs": [],
      "source": [
        "tokenizer = AutoTokenizer.from_pretrained(\"gagneurlab/SpeciesLM\", revision=\"downstream_species_lm\")\n",
        "lm = AutoModelForMaskedLM.from_pretrained(\"gagneurlab/SpeciesLM\", revision=\"downstream_species_lm\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "id": "ca43b893-2d66-4e93-a08f-b17a92040709",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ca43b893-2d66-4e93-a08f-b17a92040709",
        "outputId": "ccbac964-a329-414d-f537-3cae7da66cf2"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "BertForMaskedLM(\n",
              "  (bert): BertModel(\n",
              "    (embeddings): BertEmbeddings(\n",
              "      (word_embeddings): Embedding(5504, 768, padding_idx=0)\n",
              "      (position_embeddings): Embedding(512, 768)\n",
              "      (token_type_embeddings): Embedding(2, 768)\n",
              "      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "      (dropout): Dropout(p=0.1, inplace=False)\n",
              "    )\n",
              "    (encoder): BertEncoder(\n",
              "      (layer): ModuleList(\n",
              "        (0-11): 12 x BertLayer(\n",
              "          (attention): BertAttention(\n",
              "            (self): BertSdpaSelfAttention(\n",
              "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (dropout): Dropout(p=0.1, inplace=False)\n",
              "            )\n",
              "            (output): BertSelfOutput(\n",
              "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "              (dropout): Dropout(p=0.1, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): BertIntermediate(\n",
              "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
              "            (intermediate_act_fn): GELUActivation()\n",
              "          )\n",
              "          (output): BertOutput(\n",
              "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
              "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "            (dropout): Dropout(p=0.1, inplace=False)\n",
              "          )\n",
              "        )\n",
              "      )\n",
              "    )\n",
              "  )\n",
              "  (cls): BertOnlyMLMHead(\n",
              "    (predictions): BertLMPredictionHead(\n",
              "      (transform): BertPredictionHeadTransform(\n",
              "        (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "        (transform_act_fn): GELUActivation()\n",
              "        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "      )\n",
              "      (decoder): Linear(in_features=768, out_features=5504, bias=True)\n",
              "    )\n",
              "  )\n",
              ")"
            ]
          },
          "execution_count": 12,
          "metadata": {},
          "output_type": "execute_result"
        }
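      ],
      "source": [
        "lm.eval()\n",
        "# Use the GPU when one is available, otherwise fall back to the CPU\n",
        "lm.to(\"cuda\" if torch.cuda.is_available() else \"cpu\");"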
      ]
    },
    {
      "cell_type": "markdown",
      "id": "c1bda6f2-34bb-4ce2-aa3f-3013548b0a28",
      "metadata": {
        "id": "c1bda6f2-34bb-4ce2-aa3f-3013548b0a28"
      },
      "source": [
        "### 3. Prepare datasets"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "f4c61e59-457c-47d9-8929-5e8cd32d3125",
      "metadata": {
        "id": "f4c61e59-457c-47d9-8929-5e8cd32d3125"
      },
      "source": [
        "We'll load the `nucleotide_transformer_downstream_tasks` dataset, which contains 18 downstream tasks from the Nucleotide Transformer paper. It provides a consistent genomics benchmark of sequence classification tasks; the task used in this notebook is a binary classification task."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "id": "f5c0b3df-911a-4645-9140-99ee489515e8",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 145,
          "referenced_widgets": [
            "03bba232d3974119acf8031bc086a072",
            "9107f7bfc8d3483390f802b0458e9380",
            "f5c80fa70ead4c86aa3b2a046061b901",
            "57966a469ca1458daab74e81672ae855",
            "1464502dc3dd46308be8b4fcc9d5ddb9",
            "92f64c7e088342b9b3c070ba7a295ed0",
            "ab0aa8af3816422e9d97934f12af842c",
            "ff89a891bd9c42a8be164587a94ccac1",
            "e113a50f8ed2410ca12ce7cb38a1681d",
            "1afa6e9b69c74136863b7747e62a0608",
            "0838d19b226d486285a26ce0b04d7e15",
            "7bdab33f4b244fc89408b91755bf17c5",
            "4d4ce0d35c124690b3427e84a9a128b1",
            "33be6b0ca8fd44188f834a48a9574a72",
            "74e9bc1ead434ae78077df6b85f1df58",
            "e1acc6e70b9246a5b063b3e262f01c81",
            "078c6877377a491d97d6fadd27064a76",
            "d46ee1c39bac44c2b541a88c883de1cb",
            "12f1de7122a7471e90f01d9e7be81178",
            "dad286d42a514c9ca6bb01bfe9e9c4be",
            "c028ed977b5e479fbd93b8add588a6dc",
            "6d80dec073e449efba272fa9f3527922",
            "c311b777514f41ef986756a386c0bb34",
            "e2e4bf053ce442f6aee6ffab5f76525f",
            "c88cf701e20b4354a63ac7d8645d1df9",
            "f71c252ada474be882b0335ed9a0a1c3",
            "e059c665229e46ea905dcbd6fc179c88",
            "bd5273325a4b453e8053d98a09fe9493",
            "8f20ed2b74d84e80a8d403793354adea",
            "57c9af47364d48ffbb4ffbdd2c951ede",
            "fa9d75fcb1d5400c8ca1d1d13d28d0c7",
            "682644a713b145f0b2dcff99790c6d4d",
            "9b9b9d573d44464f9a6f5030a40245fe",
            "ec165fdbe87a4b00a6c288ef1e85c0a9",
            "17859b793a304e389d1ea0b9ccc3646f",
            "34921fd116cc42b7b530174d9f61e71e",
            "2d5466a5e98849c5a09f16faa98f91da",
            "952397f9c91c480184fa57e175ab1b4c",
            "86bcccb842244f4f9add58f62facaace",
            "78b5bbf4c8ac4fe5961776fded4d5798",
            "c80062a855cb41a28ac625ab03635da2",
            "aecd740c17c84d45b0615d4fc4196035",
            "39640709e7174f84a50da05764abbf99",
            "7114a029e75c4ed5b966eddd3a3c919d"
          ]
        },
        "id": "f5c0b3df-911a-4645-9140-99ee489515e8",
        "outputId": "15315be1-9d07-4c46-acda-c65cb5a05250"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "03bba232d3974119acf8031bc086a072",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Downloading data:   0%|          | 0.00/3.50M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "7bdab33f4b244fc89408b91755bf17c5",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Downloading data:   0%|          | 0.00/391k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "c311b777514f41ef986756a386c0bb34",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Generating train split:   0%|          | 0/13468 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "ec165fdbe87a4b00a6c288ef1e85c0a9",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Generating test split:   0%|          | 0/1497 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "from datasets import load_dataset\n",
        "\n",
        "raw_data = load_dataset(\"InstaDeepAI/nucleotide_transformer_downstream_tasks\", \"H3\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "bbb527c5-8077-4ce4-b093-ae627a5f253c",
      "metadata": {
        "id": "bbb527c5-8077-4ce4-b093-ae627a5f253c"
      },
      "source": [
        "We'll use the \"H3\" subset of this dataset, which contains 13,468 rows in the training split and 1,497 rows in the test split."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "id": "efef4bb2-60d8-40d1-8777-2b665a87059c",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "efef4bb2-60d8-40d1-8777-2b665a87059c",
        "outputId": "1c8526ce-0fcb-4fbc-d592-f9a6eae6ebdb"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "DatasetDict({\n",
              "    train: Dataset({\n",
              "        features: ['sequence', 'name', 'label'],\n",
              "        num_rows: 13468\n",
              "    })\n",
              "    test: Dataset({\n",
              "        features: ['sequence', 'name', 'label'],\n",
              "        num_rows: 1497\n",
              "    })\n",
              "})"
            ]
          },
          "execution_count": 14,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "raw_data"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "aafd37c8-6830-4070-a73b-cf62e72e901c",
      "metadata": {
        "id": "aafd37c8-6830-4070-a73b-cf62e72e901c"
      },
      "source": [
        "The dataset consists of three columns: `sequence`, `name`, and `label`. A row in this dataset looks like this:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "id": "eecd39d8-c073-4d3e-940e-fd83d46f83ab",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "eecd39d8-c073-4d3e-940e-fd83d46f83ab",
        "outputId": "0b5f8800-eb5d-4c41-a2bc-73e4f837a4d8",
        "scrolled": true
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'sequence': 'TCACTTCGATTATTGAGGCAGTCTTCATTAAAGTTTATTACAATGGATATGGTATCACCAGTCTTGAACCTACAATCATCTATTTTAGGTGAGCTCGTAGGCATTATTGGAAAAGTGTTCTTTCTCTTAATAGAAGAGATTAAATACCCGATAATCACACCCAAAATTATTGTGGATGCCCAGATATCTTCTTGGTCATTGTTTTTTTTCGCTTCAATCTGTAATCTCTCTGCAAAATTTCGGGAGCCAATAGTGACAACATCGTCAATAATAAGTTTGATGGAATCGGAAAAAGATCTTAAAAATGTAAATGAGTATTTCCAAATAATGGCCAAAATGCTCTTTATATTGGAAAATAAAATAGTTGTTTCGCTCTTCGTAGTATTTAACATTTCCGTTCTTATCATTGTAAAGTCTGAGCCATATTCATATGGAAAAGTGCTTTTTAAACCTAGTTCCTCCATATTTTAGTTTTTTATCGATATTGGAAAAAAAAGAGC',\n",
              " 'name': 'YBR063C_YBR063C_367930|0',\n",
              " 'label': 0}"
            ]
          },
          "execution_count": 15,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "raw_data['train'][0]"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "92eccf3e-e846-4c59-af56-0e336ac5a1cd",
      "metadata": {
        "id": "92eccf3e-e846-4c59-af56-0e336ac5a1cd"
      },
      "source": [
        "We split our dataset into training, validation, and test sets. The validation set is 15% of the original training split; the test split is used as provided."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "id": "f0649bbd-e74e-4dd6-a564-c4d65e46dbbf",
      "metadata": {
        "id": "f0649bbd-e74e-4dd6-a564-c4d65e46dbbf"
      },
      "outputs": [],
      "source": [
        "from datasets import Dataset, DatasetDict\n",
        "\n",
        "# Hold out 15% of the original training split as a validation set\n",
        "train_valid_split = raw_data['train'].train_test_split(test_size=0.15, seed=42)\n",
        "\n",
        "ds = DatasetDict({\n",
        "    'train': train_valid_split['train'],\n",
        "    'validation': train_valid_split['test'],\n",
        "    'test': raw_data['test']\n",
        "})"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "5424726f-a7ba-45d5-b449-36be9a98b8e6",
      "metadata": {
        "id": "5424726f-a7ba-45d5-b449-36be9a98b8e6"
      },
      "source": [
        "Then, we use the tokenizer and a small utility function, `get_kmers`, to generate the final inputs and labels. `get_kmers` produces the overlapping 6-mers that the language model expects as input: with `k=6` and `stride=1`, a sequence of length `n` yields `n - 5` k-mers, each shifted by one nucleotide from the previous one, so the model sees the local context around every position. Each k-mer sequence is also prefixed with a species token (`candida_glabrata` in the code below) before tokenization."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "id": "75f267a9-82d1-4343-982e-9b1ea542a330",
      "metadata": {
        "id": "75f267a9-82d1-4343-982e-9b1ea542a330"
      },
      "outputs": [],
      "source": [
        "def get_kmers(seq, k=6, stride=1):\n",
        "    \"\"\"Return the k-mers of `seq`, starting a new k-mer every `stride` positions.\"\"\"\n",
        "    return [seq[i:i + k] for i in range(0, len(seq) - k + 1, stride)]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "id": "efa9441d-f44c-4ca3-b24c-fa5c853896cd",
      "metadata": {
        "id": "efa9441d-f44c-4ca3-b24c-fa5c853896cd"
      },
      "outputs": [],
      "source": [
        "test_sequences = []\n",
        "train_sequences = []\n",
        "val_sequences = []\n",
        "\n",
        "# NOTE: Each split is limited to 200 examples so that training runs faster.\n",
        "# Set dataset_limit to None to use the entire dataset.\n",
        "dataset_limit = 200\n",
        "\n",
        "for i in range(0, len(ds['train'])):\n",
        "\n",
        "    if dataset_limit and i == dataset_limit:\n",
        "        break\n",
        "\n",
        "    sequence = ds['train'][i]['sequence']\n",
        "    sequence = \"candida_glabrata \" + \" \".join(get_kmers(sequence))\n",
        "    sequence = tokenizer(sequence)[\"input_ids\"]\n",
        "    train_sequences.append(sequence)\n",
        "\n",
        "\n",
        "for i in range(0, len(ds['validation'])):\n",
        "    if dataset_limit and i == dataset_limit:\n",
        "        break\n",
        "    sequence = ds['validation'][i]['sequence']\n",
        "    sequence = \"candida_glabrata \" + \" \".join(get_kmers(sequence))\n",
        "    sequence = tokenizer(sequence)[\"input_ids\"]\n",
        "    val_sequences.append(sequence)\n",
        "\n",
        "\n",
        "for i in range(0, len(ds['test'])):\n",
        "    if dataset_limit and i == dataset_limit:\n",
        "        break\n",
        "    sequence = ds['test'][i]['sequence']\n",
        "    sequence = \"candida_glabrata \" + \" \".join(get_kmers(sequence))\n",
        "    sequence = tokenizer(sequence)[\"input_ids\"]\n",
        "    test_sequences.append(sequence)\n",
        "\n",
        "\n",
        "train_labels = ds['train']['label']\n",
        "test_labels = ds['test']['label']\n",
        "val_labels = ds['validation']['label']\n",
        "\n",
        "if dataset_limit:\n",
        "    train_labels = train_labels[0:dataset_limit]\n",
        "    test_labels = test_labels[0:dataset_limit]\n",
        "    val_labels = val_labels[0:dataset_limit]"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "0686955c-201a-427b-8bef-5c663edb85b8",
      "metadata": {
        "id": "0686955c-201a-427b-8bef-5c663edb85b8"
      },
      "source": [
        "Finally, we create a `Dataset` object for each of our sets."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "id": "445b4279-2446-46d6-af2a-ceb2638955c7",
      "metadata": {
        "id": "445b4279-2446-46d6-af2a-ceb2638955c7"
      },
      "outputs": [],
      "source": [
        "from datasets import Dataset\n",
        "\n",
        "train_dataset = Dataset.from_dict({\"input_ids\": train_sequences, \"labels\": train_labels})\n",
        "val_dataset = Dataset.from_dict({\"input_ids\": val_sequences, \"labels\": val_labels})\n",
        "test_dataset = Dataset.from_dict({\"input_ids\": test_sequences, \"labels\": test_labels})"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "d05d51a7-b933-4793-95df-af7d4d510b13",
      "metadata": {
        "id": "d05d51a7-b933-4793-95df-af7d4d510b13"
      },
      "source": [
        "### 4. Train model"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "b5ce1985-c24e-4feb-a6d4-aacb909536f0",
      "metadata": {
        "id": "b5ce1985-c24e-4feb-a6d4-aacb909536f0"
      },
      "source": [
        "Now, we'll train our DNA language model on the training dataset. We'll add a linear classification layer on top of the language model's final hidden state, and then train all the parameters of the model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 20,
      "id": "3a34b2c0-6205-4d48-b1a6-371b50ca42de",
      "metadata": {
        "id": "3a34b2c0-6205-4d48-b1a6-371b50ca42de"
      },
      "outputs": [],
      "source": [
        "from transformers import DataCollatorWithPadding\n",
        "\n",
        "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 32,
      "id": "700540f4-0ab8-4f8a-a75c-416a6908af47",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "700540f4-0ab8-4f8a-a75c-416a6908af47",
        "outputId": "9e16c1e9-4676-4cdf-b2a9-d785773b1c8d"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "DNA_LM(\n",
              "  (model): BertModel(\n",
              "    (embeddings): BertEmbeddings(\n",
              "      (word_embeddings): Embedding(5504, 768, padding_idx=0)\n",
              "      (position_embeddings): Embedding(512, 768)\n",
              "      (token_type_embeddings): Embedding(2, 768)\n",
              "      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "      (dropout): Dropout(p=0.1, inplace=False)\n",
              "    )\n",
              "    (encoder): BertEncoder(\n",
              "      (layer): ModuleList(\n",
              "        (0-11): 12 x BertLayer(\n",
              "          (attention): BertAttention(\n",
              "            (self): BertSdpaSelfAttention(\n",
              "              (query): lora.Linear(\n",
              "                (base_layer): Linear(in_features=768, out_features=768, bias=True)\n",
              "                (lora_dropout): ModuleDict(\n",
              "                  (default): Dropout(p=0.01, inplace=False)\n",
              "                )\n",
              "                (lora_A): ModuleDict(\n",
              "                  (default): Linear(in_features=768, out_features=8, bias=False)\n",
              "                )\n",
              "                (lora_B): ModuleDict(\n",
              "                  (default): Linear(in_features=8, out_features=768, bias=False)\n",
              "                )\n",
              "                (lora_embedding_A): ParameterDict()\n",
              "                (lora_embedding_B): ParameterDict()\n",
              "              )\n",
              "              (key): lora.Linear(\n",
              "                (base_layer): Linear(in_features=768, out_features=768, bias=True)\n",
              "                (lora_dropout): ModuleDict(\n",
              "                  (default): Dropout(p=0.01, inplace=False)\n",
              "                )\n",
              "                (lora_A): ModuleDict(\n",
              "                  (default): Linear(in_features=768, out_features=8, bias=False)\n",
              "                )\n",
              "                (lora_B): ModuleDict(\n",
              "                  (default): Linear(in_features=8, out_features=768, bias=False)\n",
              "                )\n",
              "                (lora_embedding_A): ParameterDict()\n",
              "                (lora_embedding_B): ParameterDict()\n",
              "              )\n",
              "              (value): lora.Linear(\n",
              "                (base_layer): Linear(in_features=768, out_features=768, bias=True)\n",
              "                (lora_dropout): ModuleDict(\n",
              "                  (default): Dropout(p=0.01, inplace=False)\n",
              "                )\n",
              "                (lora_A): ModuleDict(\n",
              "                  (default): Linear(in_features=768, out_features=8, bias=False)\n",
              "                )\n",
              "                (lora_B): ModuleDict(\n",
              "                  (default): Linear(in_features=8, out_features=768, bias=False)\n",
              "                )\n",
              "                (lora_embedding_A): ParameterDict()\n",
              "                (lora_embedding_B): ParameterDict()\n",
              "              )\n",
              "              (dropout): Dropout(p=0.1, inplace=False)\n",
              "            )\n",
              "            (output): BertSelfOutput(\n",
              "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "              (dropout): Dropout(p=0.1, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): BertIntermediate(\n",
              "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
              "            (intermediate_act_fn): GELUActivation()\n",
              "          )\n",
              "          (output): BertOutput(\n",
              "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
              "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "            (dropout): Dropout(p=0.1, inplace=False)\n",
              "          )\n",
              "        )\n",
              "      )\n",
              "    )\n",
              "  )\n",
              "  (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
              ")"
            ]
          },
          "execution_count": 32,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "import torch\n",
        "from torch import nn\n",
        "\n",
        "class DNA_LM(nn.Module):\n",
        "    \"\"\"BERT encoder of the pre-trained LM with a linear classification head.\"\"\"\n",
        "\n",
        "    def __init__(self, model, num_labels):\n",
        "        super().__init__()\n",
        "        self.model = model.bert  # reuse the encoder, drop the masked-LM head\n",
        "        self.in_features = model.config.hidden_size\n",
        "        self.out_features = num_labels\n",
        "        self.classifier = nn.Linear(self.in_features, self.out_features)\n",
        "\n",
        "    def forward(self, input_ids, attention_mask=None, labels=None):\n",
        "        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)\n",
        "        # Final-layer hidden states, shape (batch, seq_len, hidden_size)\n",
        "        sequence_output = outputs.last_hidden_state\n",
        "        # Use the [CLS] token (position 0) for classification\n",
        "        cls_output = sequence_output[:, 0, :]\n",
        "        logits = self.classifier(cls_output)\n",
        "\n",
        "        loss = None\n",
        "        if labels is not None:\n",
        "            loss_fct = nn.CrossEntropyLoss()\n",
        "            loss = loss_fct(logits.view(-1, self.out_features), labels.view(-1))\n",
        "\n",
        "        # Return (loss, logits) when labels are given, as the Trainer expects\n",
        "        return (loss, logits) if loss is not None else logits\n",
        "\n",
        "# Number of classes for the classification task\n",
        "num_labels = 2\n",
        "classification_model = DNA_LM(lm, num_labels)\n",
        "# Use the GPU when one is available, otherwise fall back to the CPU\n",
        "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
        "classification_model.to(device);"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 36,
      "id": "d9ce6bc3-4f63-4b7b-b28d-d2553002e6db",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 268
        },
        "id": "d9ce6bc3-4f63-4b7b-b28d-d2553002e6db",
        "outputId": "0c8fdbad-f34d-492b-e146-db6c2064e7c5"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='65' max='65' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [65/65 01:43, Epoch 5/5]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Epoch</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.887400</td>\n",
              "      <td>0.685295</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.644700</td>\n",
              "      <td>0.682495</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.599600</td>\n",
              "      <td>0.680431</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>4</td>\n",
              "      <td>0.892800</td>\n",
              "      <td>0.679170</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>5</td>\n",
              "      <td>0.663800</td>\n",
              "      <td>0.678761</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/plain": [
              "TrainOutput(global_step=65, training_loss=0.7263066686116733, metrics={'train_runtime': 104.8696, 'train_samples_per_second': 9.536, 'train_steps_per_second': 0.62, 'total_flos': 0.0, 'train_loss': 0.7263066686116733, 'epoch': 5.0})"
            ]
          },
          "execution_count": 36,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "from transformers import Trainer, TrainingArguments\n",
        "\n",
        "\n",
        "# Define training arguments\n",
        "training_args = TrainingArguments(\n",
        "    output_dir='./results',\n",
        "    eval_strategy=\"epoch\",\n",
        "    learning_rate=2e-5,\n",
        "    per_device_train_batch_size=16,\n",
        "    per_device_eval_batch_size=16,\n",
        "    num_train_epochs=5,\n",
        "    weight_decay=0.01,\n",
        "    logging_steps=1,\n",
        ")\n",
        "\n",
        "# Initialize Trainer\n",
        "trainer = Trainer(\n",
        "    model=classification_model,\n",
        "    args=training_args,\n",
        "    train_dataset=train_dataset,\n",
        "    eval_dataset=val_dataset,\n",
        "    tokenizer=tokenizer,\n",
        "    data_collator=data_collator,\n",
        ")\n",
        "\n",
        "# Train the model\n",
        "trainer.train()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "ebc7e33a-caad-4412-84e3-3e1ce7d02ccd",
      "metadata": {
        "id": "ebc7e33a-caad-4412-84e3-3e1ce7d02ccd"
      },
      "source": [
        "### 5. Evaluation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 37,
      "id": "38eb0273-ce7e-4770-8457-2f9609f6843b",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 124
        },
        "id": "38eb0273-ce7e-4770-8457-2f9609f6843b",
        "outputId": "2b0b93c9-0199-4e71-9825-9f6a2bd199d0"
      },
      "outputs": [
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[0 1 1 1 1 1 1 1 1 0 1 0 1 1 1 1 0 1 1 0 1 1 0 1 1 1 1 0 1 0 0 0 1 1 0 1 1\n",
            " 1 1 1 0 1 1 1 0 1 1 0 1 1 1 1 1 1 1 1 0 1 0 0 1 1 1 1 1 0 0 0 1 0 1 1 0 1\n",
            " 0 1 1 0 1 1 1 0 0 1 0 1 0 1 0 1 1 1 0 1 1 1 1 0 1 0 0 0 0 1 0 1 0 0 1 1 1\n",
            " 1 0 1 1 0 0 1 1 1 0 1 1 1 1 0 0 1 1 1 1 0 0 1 1 1 0 0 1 1 0 1 1 0 1 1 0 1\n",
            " 1 1 1 1 1 1 0 0 1 1 1 1 1 1 1 1 0 0 1 0 1 1 1 1 1 1 1 0 1 1 1 0 0 1 1 1 1\n",
            " 0 1 1 1 1 0 1 1 0 0 1 0 1 1 0]\n"
          ]
        }
      ],
      "source": [
        "# Generate predictions\n",
        "\n",
        "predictions = trainer.predict(test_dataset)\n",
        "logits = predictions.predictions\n",
        "predicted_labels = logits.argmax(axis=-1)\n",
        "print(predicted_labels)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "ae4c7bca",
      "metadata": {
        "id": "ae4c7bca"
      },
      "source": [
        "Then, we create a function to calculate the accuracy from the test and predicted labels."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 38,
      "id": "327a1c3b-88d6-4430-8978-73a7cbdbb697",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "327a1c3b-88d6-4430-8978-73a7cbdbb697",
        "outputId": "f03ad54d-d35f-4fcc-e709-c24d14906e25"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Accuracy: 0.53\n"
          ]
        }
      ],
      "source": [
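        "def calculate_accuracy(true_labels, predicted_labels):\n",
        "    \"\"\"Return the fraction of predictions that match the true labels.\"\"\"\n",
        "    # Convert to arrays so the comparison below is element-wise\n",
        "    true_labels = np.asarray(true_labels)\n",
        "    predicted_labels = np.asarray(predicted_labels)\n",
        "\n",
        "    assert len(true_labels) == len(predicted_labels), \"Arrays must have the same length\"\n",
        "    correct_predictions = np.sum(true_labels == predicted_labels)\n",
        "    accuracy = correct_predictions / len(true_labels)\n",
        "\n",
        "    return accuracy\n",
        "\n",
        "accuracy = calculate_accuracy(test_labels, predicted_labels)\n",
        "print(f\"Accuracy: {accuracy:.2f}\")"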
      ]
    },
    {
      "cell_type": "markdown",
      "id": "9p0fFXKTZz9Q",
      "metadata": {
        "id": "9p0fFXKTZz9Q"
      },
      "source": [
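        "An accuracy of 0.53 is only slightly above 0.5 on this two-class task, so the fine-tuned model has learned little. The validation loss also barely moves over the five epochs (0.685 to 0.679). We attribute this mainly to the small dataset size."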
      ]
    },
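    {
      "cell_type": "markdown",
      "id": "5c1e8a3f",
      "metadata": {
        "id": "5c1e8a3f"
      },
      "source": [
        "To put the accuracy above in context, it can help to compare it with the accuracy of always predicting the most frequent class. The cell below is a minimal sketch of that check. It assumes `test_labels` holds the integer class labels (0 or 1) used above; the names `class_counts` and `majority_baseline` are introduced here only for illustration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "d2b7f604",
      "metadata": {
        "id": "d2b7f604"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Assumption: test_labels contains non-negative integer class labels (0 or 1)\n",
        "labels = np.asarray(test_labels)\n",
        "\n",
        "# Count how often each class occurs in the test set\n",
        "class_counts = np.bincount(labels, minlength=2)\n",
        "\n",
        "# Accuracy obtained by always predicting the most frequent class\n",
        "majority_baseline = class_counts.max() / len(labels)\n",
        "\n",
        "print(f\"Class counts: {class_counts.tolist()}\")\n",
        "print(f\"Majority-class baseline accuracy: {majority_baseline:.2f}\")"
      ]
    },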
    {
      "cell_type": "markdown",
      "id": "e681864c-f15a-40a6-ac34-0e631d68d5c8",
      "metadata": {
        "id": "e681864c-f15a-40a6-ac34-0e631d68d5c8"
      },
      "source": [
        "### 6. Parameter-Efficient Fine-Tuning Techniques"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "9141fabe-417b-4fbb-bd3e-244ad84e3010",
      "metadata": {
        "id": "9141fabe-417b-4fbb-bd3e-244ad84e3010"
      },
      "source": [
        "In this section, we demonstrate how to employ parameter-efficient fine-tuning (PEFT) techniques to adapt a pre-trained model for specific genomics tasks using the PEFT library."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "71b8a749-461e-4533-b1d0-cebc924d3dc0",
      "metadata": {
        "id": "71b8a749-461e-4533-b1d0-cebc924d3dc0"
      },
      "source": [
        "The LoraConfig object is instantiated to configure the PEFT parameters:\n",
        "\n",
        "- task_type: Specifies the type of task, for example sequence classification (SEQ_CLS). It is optional and is not set in the configuration below.\n",
        "- r: The rank of the low-rank LoRA update matrices.\n",
        "- lora_alpha: Scaling factor for the LoRA update, which is multiplied by lora_alpha / r.\n",
        "- target_modules: Modules within the model that receive LoRA adapters (the query, key, and value projections in this example).\n",
        "- lora_dropout: Dropout probability applied to the input of the LoRA layers during fine-tuning."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 40,
      "id": "021641ae-f604-4d69-8724-743b7d7c613c",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "021641ae-f604-4d69-8724-743b7d7c613c",
        "outputId": "d7c41fca-1c6b-46fd-9116-01f42d1d6ddf"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "DNA_LM(\n",
              "  (model): BertModel(\n",
              "    (embeddings): BertEmbeddings(\n",
              "      (word_embeddings): Embedding(5504, 768, padding_idx=0)\n",
              "      (position_embeddings): Embedding(512, 768)\n",
              "      (token_type_embeddings): Embedding(2, 768)\n",
              "      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "      (dropout): Dropout(p=0.1, inplace=False)\n",
              "    )\n",
              "    (encoder): BertEncoder(\n",
              "      (layer): ModuleList(\n",
              "        (0-11): 12 x BertLayer(\n",
              "          (attention): BertAttention(\n",
              "            (self): BertSdpaSelfAttention(\n",
              "              (query): lora.Linear(\n",
              "                (base_layer): Linear(in_features=768, out_features=768, bias=True)\n",
              "                (lora_dropout): ModuleDict(\n",
              "                  (default): Dropout(p=0.01, inplace=False)\n",
              "                )\n",
              "                (lora_A): ModuleDict(\n",
              "                  (default): Linear(in_features=768, out_features=8, bias=False)\n",
              "                )\n",
              "                (lora_B): ModuleDict(\n",
              "                  (default): Linear(in_features=8, out_features=768, bias=False)\n",
              "                )\n",
              "                (lora_embedding_A): ParameterDict()\n",
              "                (lora_embedding_B): ParameterDict()\n",
              "              )\n",
              "              (key): lora.Linear(\n",
              "                (base_layer): Linear(in_features=768, out_features=768, bias=True)\n",
              "                (lora_dropout): ModuleDict(\n",
              "                  (default): Dropout(p=0.01, inplace=False)\n",
              "                )\n",
              "                (lora_A): ModuleDict(\n",
              "                  (default): Linear(in_features=768, out_features=8, bias=False)\n",
              "                )\n",
              "                (lora_B): ModuleDict(\n",
              "                  (default): Linear(in_features=8, out_features=768, bias=False)\n",
              "                )\n",
              "                (lora_embedding_A): ParameterDict()\n",
              "                (lora_embedding_B): ParameterDict()\n",
              "              )\n",
              "              (value): lora.Linear(\n",
              "                (base_layer): Linear(in_features=768, out_features=768, bias=True)\n",
              "                (lora_dropout): ModuleDict(\n",
              "                  (default): Dropout(p=0.01, inplace=False)\n",
              "                )\n",
              "                (lora_A): ModuleDict(\n",
              "                  (default): Linear(in_features=768, out_features=8, bias=False)\n",
              "                )\n",
              "                (lora_B): ModuleDict(\n",
              "                  (default): Linear(in_features=8, out_features=768, bias=False)\n",
              "                )\n",
              "                (lora_embedding_A): ParameterDict()\n",
              "                (lora_embedding_B): ParameterDict()\n",
              "              )\n",
              "              (dropout): Dropout(p=0.1, inplace=False)\n",
              "            )\n",
              "            (output): BertSelfOutput(\n",
              "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "              (dropout): Dropout(p=0.1, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): BertIntermediate(\n",
              "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
              "            (intermediate_act_fn): GELUActivation()\n",
              "          )\n",
              "          (output): BertOutput(\n",
              "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
              "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "            (dropout): Dropout(p=0.1, inplace=False)\n",
              "          )\n",
              "        )\n",
              "      )\n",
              "    )\n",
              "  )\n",
              "  (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
              ")"
            ]
          },
          "execution_count": 40,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Number of classes for your classification task\n",
        "num_labels = 2\n",
        "classification_model = DNA_LM(lm, num_labels)\n",
        "classification_model.to('cuda');"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 41,
      "id": "6c223937-86ea-42ef-991a-050f23b21ef9",
      "metadata": {
        "id": "6c223937-86ea-42ef-991a-050f23b21ef9"
      },
      "outputs": [],
      "source": [
        "from peft import LoraConfig, TaskType\n",
        "\n",
        "peft_config = LoraConfig(\n",
        "    r=8,\n",
        "    lora_alpha=32,\n",
        "    target_modules=[\"query\", \"key\", \"value\"],\n",
        "    lora_dropout=0.01,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 42,
      "id": "e7a9fe7d-e3ac-4ffa-9a9b-2067fb09b885",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "e7a9fe7d-e3ac-4ffa-9a9b-2067fb09b885",
        "outputId": "02a6c65f-7474-4bc1-bfab-c05532e350a5"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "trainable params: 442,368 || all params: 90,121,730 || trainable%: 0.4909\n"
          ]
        }
      ],
      "source": [
        "from peft import get_peft_model\n",
        "\n",
        "peft_model = get_peft_model(classification_model, peft_config)\n",
        "peft_model.print_trainable_parameters()"
      ]
    },
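    {
      "cell_type": "markdown",
      "id": "7a94c0e1",
      "metadata": {
        "id": "7a94c0e1"
      },
      "source": [
        "The trainable-parameter count printed above can be checked by hand. The following sketch assumes the layer sizes shown in the model printout (12 encoder layers with 768-dimensional query, key, and value projections) and the rank r=8 from the LoraConfig; the variable names are illustrative. If the result equals the reported count, the LoRA matrices account for every trainable parameter, which suggests that the classifier head is not being trained in this configuration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "e3f1b8d6",
      "metadata": {
        "id": "e3f1b8d6"
      },
      "outputs": [],
      "source": [
        "# Values read off the model printout and the LoraConfig in this notebook\n",
        "hidden_size = 768        # in/out features of the query, key and value projections\n",
        "rank = 8                 # r in LoraConfig\n",
        "num_layers = 12          # BertLayer blocks in the encoder\n",
        "num_target_modules = 3   # query, key, value\n",
        "\n",
        "# Each adapted projection adds lora_A (768 -> 8) and lora_B (8 -> 768), both without bias\n",
        "params_per_module = hidden_size * rank + rank * hidden_size\n",
        "lora_params = num_layers * num_target_modules * params_per_module\n",
        "print(f\"LoRA parameters: {lora_params:,}\")  # 442,368\n",
        "\n",
        "# 'all params' as reported by print_trainable_parameters() above\n",
        "total_params = 90_121_730\n",
        "print(f\"Trainable share: {100 * lora_params / total_params:.4f}%\")  # 0.4909%"
      ]
    },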
    {
      "cell_type": "code",
      "execution_count": 43,
      "id": "22064519-eaab-4142-8618-d1210d05c6bd",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "22064519-eaab-4142-8618-d1210d05c6bd",
        "outputId": "ca3f764d-cdb4-4525-c541-8eabfb4cde57"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "PeftModel(\n",
              "  (base_model): LoraModel(\n",
              "    (model): DNA_LM(\n",
              "      (model): BertModel(\n",
              "        (embeddings): BertEmbeddings(\n",
              "          (word_embeddings): Embedding(5504, 768, padding_idx=0)\n",
              "          (position_embeddings): Embedding(512, 768)\n",
              "          (token_type_embeddings): Embedding(2, 768)\n",
              "          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "          (dropout): Dropout(p=0.1, inplace=False)\n",
              "        )\n",
              "        (encoder): BertEncoder(\n",
              "          (layer): ModuleList(\n",
              "            (0-11): 12 x BertLayer(\n",
              "              (attention): BertAttention(\n",
              "                (self): BertSdpaSelfAttention(\n",
              "                  (query): lora.Linear(\n",
              "                    (base_layer): Linear(in_features=768, out_features=768, bias=True)\n",
              "                    (lora_dropout): ModuleDict(\n",
              "                      (default): Dropout(p=0.01, inplace=False)\n",
              "                    )\n",
              "                    (lora_A): ModuleDict(\n",
              "                      (default): Linear(in_features=768, out_features=8, bias=False)\n",
              "                    )\n",
              "                    (lora_B): ModuleDict(\n",
              "                      (default): Linear(in_features=8, out_features=768, bias=False)\n",
              "                    )\n",
              "                    (lora_embedding_A): ParameterDict()\n",
              "                    (lora_embedding_B): ParameterDict()\n",
              "                  )\n",
              "                  (key): lora.Linear(\n",
              "                    (base_layer): Linear(in_features=768, out_features=768, bias=True)\n",
              "                    (lora_dropout): ModuleDict(\n",
              "                      (default): Dropout(p=0.01, inplace=False)\n",
              "                    )\n",
              "                    (lora_A): ModuleDict(\n",
              "                      (default): Linear(in_features=768, out_features=8, bias=False)\n",
              "                    )\n",
              "                    (lora_B): ModuleDict(\n",
              "                      (default): Linear(in_features=8, out_features=768, bias=False)\n",
              "                    )\n",
              "                    (lora_embedding_A): ParameterDict()\n",
              "                    (lora_embedding_B): ParameterDict()\n",
              "                  )\n",
              "                  (value): lora.Linear(\n",
              "                    (base_layer): Linear(in_features=768, out_features=768, bias=True)\n",
              "                    (lora_dropout): ModuleDict(\n",
              "                      (default): Dropout(p=0.01, inplace=False)\n",
              "                    )\n",
              "                    (lora_A): ModuleDict(\n",
              "                      (default): Linear(in_features=768, out_features=8, bias=False)\n",
              "                    )\n",
              "                    (lora_B): ModuleDict(\n",
              "                      (default): Linear(in_features=8, out_features=768, bias=False)\n",
              "                    )\n",
              "                    (lora_embedding_A): ParameterDict()\n",
              "                    (lora_embedding_B): ParameterDict()\n",
              "                  )\n",
              "                  (dropout): Dropout(p=0.1, inplace=False)\n",
              "                )\n",
              "                (output): BertSelfOutput(\n",
              "                  (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "                  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "                  (dropout): Dropout(p=0.1, inplace=False)\n",
              "                )\n",
              "              )\n",
              "              (intermediate): BertIntermediate(\n",
              "                (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
              "                (intermediate_act_fn): GELUActivation()\n",
              "              )\n",
              "              (output): BertOutput(\n",
              "                (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
              "                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "                (dropout): Dropout(p=0.1, inplace=False)\n",
              "              )\n",
              "            )\n",
              "          )\n",
              "        )\n",
              "      )\n",
              "      (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
              "    )\n",
              "  )\n",
              ")"
            ]
          },
          "execution_count": 43,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "peft_model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 45,
      "id": "d3812e96-6b49-4911-8b21-d8871b7c06a5",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 268
        },
        "id": "d3812e96-6b49-4911-8b21-d8871b7c06a5",
        "outputId": "8d497e30-1d3f-457a-f62a-244731698cb2"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='65' max='65' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [65/65 01:39, Epoch 5/5]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Epoch</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.625700</td>\n",
              "      <td>0.777132</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.717200</td>\n",
              "      <td>0.773871</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.768200</td>\n",
              "      <td>0.771541</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>4</td>\n",
              "      <td>0.687400</td>\n",
              "      <td>0.769679</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>5</td>\n",
              "      <td>0.552000</td>\n",
              "      <td>0.768947</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/plain": [
              "TrainOutput(global_step=65, training_loss=0.74742647592838, metrics={'train_runtime': 100.8429, 'train_samples_per_second': 9.916, 'train_steps_per_second': 0.645, 'total_flos': 0.0, 'train_loss': 0.74742647592838, 'epoch': 5.0})"
            ]
          },
          "execution_count": 45,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Define training arguments\n",
        "training_args = TrainingArguments(\n",
        "    output_dir='./results',\n",
        "    eval_strategy=\"epoch\",\n",
        "    learning_rate=2e-5,\n",
        "    per_device_train_batch_size=16,\n",
        "    per_device_eval_batch_size=16,\n",
        "    num_train_epochs=5,\n",
        "    weight_decay=0.01,\n",
        "    logging_steps=1,\n",
        ")\n",
        "\n",
        "# Initialize Trainer\n",
        "trainer = Trainer(\n",
        "    model=peft_model.model,\n",
        "    args=training_args,\n",
        "    train_dataset=train_dataset,\n",
        "    eval_dataset=val_dataset,\n",
        "    tokenizer=tokenizer,\n",
        "    data_collator=data_collator,\n",
        ")\n",
        "\n",
        "# Train the model\n",
        "trainer.train()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "76dbd948-d919-4ade-a405-cec297979577",
      "metadata": {
        "id": "76dbd948-d919-4ade-a405-cec297979577"
      },
      "source": [
        "### 7. Evaluate PEFT Model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 46,
      "id": "58cf70ba-47d5-4111-bb12-830ae04c6285",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 124
        },
        "id": "58cf70ba-47d5-4111-bb12-830ae04c6285",
        "outputId": "0abc56a9-bd68-4e4e-9f13-756e8c9ffa3e"
      },
      "outputs": [
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[1 0 1 0 0 1 1 0 1 1 1 1 0 1 1 1 0 1 0 0 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 1 1\n",
            " 1 1 1 0 1 1 0 1 0 0 1 0 0 1 1 0 1 1 0 0 1 1 0 0 1 1 0 0 0 0 0 0 0 1 1 0 1\n",
            " 1 0 1 0 0 1 1 0 1 0 1 0 1 0 0 1 1 0 0 0 1 1 1 0 1 1 0 1 0 0 1 1 0 1 1 1 0\n",
            " 1 1 0 0 1 0 1 1 1 0 1 1 0 1 1 0 0 0 0 1 1 0 1 1 1 1 1 0 1 0 1 0 1 1 0 1 1\n",
            " 0 1 1 1 1 1 1 1 0 1 1 0 1 0 0 0 0 0 0 1 1 0 0 0 1 1 1 1 1 0 0 1 0 1 0 1 0\n",
            " 0 1 1 0 0 0 1 0 1 1 1 0 1 1 0]\n"
          ]
        }
      ],
      "source": [
        "# Generate predictions\n",
        "\n",
        "predictions = trainer.predict(test_dataset)\n",
        "logits = predictions.predictions\n",
        "predicted_labels = logits.argmax(axis=-1)\n",
        "print(predicted_labels)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 47,
      "id": "4bd38fe5-6513-4c88-afee-0cc4e1781fdd",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "4bd38fe5-6513-4c88-afee-0cc4e1781fdd",
        "outputId": "a50a91d0-d04d-4620-9006-868716bb992d"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Accuracy: 0.52\n"
          ]
        }
      ],
      "source": [
        "# Reuse calculate_accuracy() defined in the baseline evaluation above\n",
        "accuracy = calculate_accuracy(test_labels, predicted_labels)\n",
        "print(f\"Accuracy: {accuracy:.2f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "4ba5af69",
      "metadata": {},
      "source": [
        "The PEFT model reaches an accuracy of 0.52, close to the 0.53 of the baseline model. Both values are only slightly above 0.5, so on this small dataset the comparison shows that LoRA performs on par with the baseline rather than that either model has learned the task well.\n",
        "\n",
        "With PEFT, we only train 442,368 parameters, which is 0.49% of the total parameters in the model. This greatly reduces the number of weights that are updated, and the optimizer state kept for them, compared to updating every weight in the model. In this run the training time was similar (about 101 s with LoRA versus about 105 s for the baseline).\n",
        "\n",
        "We can improve the results by using a larger dataset, fine-tuning the model for more epochs, or changing the hyperparameters (rank, learning rate, etc.).\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.3"
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "03bba232d3974119acf8031bc086a072": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_9107f7bfc8d3483390f802b0458e9380",
              "IPY_MODEL_f5c80fa70ead4c86aa3b2a046061b901",
              "IPY_MODEL_57966a469ca1458daab74e81672ae855"
            ],
            "layout": "IPY_MODEL_1464502dc3dd46308be8b4fcc9d5ddb9"
          }
        },
        "078c6877377a491d97d6fadd27064a76": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "0838d19b226d486285a26ce0b04d7e15": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "12f1de7122a7471e90f01d9e7be81178": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "1464502dc3dd46308be8b4fcc9d5ddb9": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "17859b793a304e389d1ea0b9ccc3646f": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_86bcccb842244f4f9add58f62facaace",
            "placeholder": "​",
            "style": "IPY_MODEL_78b5bbf4c8ac4fe5961776fded4d5798",
            "value": "Generating test split: 100%"
          }
        },
        "1afa6e9b69c74136863b7747e62a0608": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "2d5466a5e98849c5a09f16faa98f91da": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_39640709e7174f84a50da05764abbf99",
            "placeholder": "​",
            "style": "IPY_MODEL_7114a029e75c4ed5b966eddd3a3c919d",
            "value": " 1497/1497 [00:00&lt;00:00, 41394.98 examples/s]"
          }
        },
        "33be6b0ca8fd44188f834a48a9574a72": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_12f1de7122a7471e90f01d9e7be81178",
            "max": 390606,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_dad286d42a514c9ca6bb01bfe9e9c4be",
            "value": 390606
          }
        },
        "34921fd116cc42b7b530174d9f61e71e": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_c80062a855cb41a28ac625ab03635da2",
            "max": 1497,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_aecd740c17c84d45b0615d4fc4196035",
            "value": 1497
          }
        },
        "39640709e7174f84a50da05764abbf99": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "4d4ce0d35c124690b3427e84a9a128b1": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_078c6877377a491d97d6fadd27064a76",
            "placeholder": "​",
            "style": "IPY_MODEL_d46ee1c39bac44c2b541a88c883de1cb",
            "value": "Downloading data: 100%"
          }
        },
        "57966a469ca1458daab74e81672ae855": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_1afa6e9b69c74136863b7747e62a0608",
            "placeholder": "​",
            "style": "IPY_MODEL_0838d19b226d486285a26ce0b04d7e15",
            "value": " 3.50M/3.50M [00:00&lt;00:00, 26.3MB/s]"
          }
        },
        "57c9af47364d48ffbb4ffbdd2c951ede": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "682644a713b145f0b2dcff99790c6d4d": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "6d80dec073e449efba272fa9f3527922": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "7114a029e75c4ed5b966eddd3a3c919d": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "74e9bc1ead434ae78077df6b85f1df58": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_c028ed977b5e479fbd93b8add588a6dc",
            "placeholder": "​",
            "style": "IPY_MODEL_6d80dec073e449efba272fa9f3527922",
            "value": " 391k/391k [00:00&lt;00:00, 3.34MB/s]"
          }
        },
        "78b5bbf4c8ac4fe5961776fded4d5798": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "7bdab33f4b244fc89408b91755bf17c5": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_4d4ce0d35c124690b3427e84a9a128b1",
              "IPY_MODEL_33be6b0ca8fd44188f834a48a9574a72",
              "IPY_MODEL_74e9bc1ead434ae78077df6b85f1df58"
            ],
            "layout": "IPY_MODEL_e1acc6e70b9246a5b063b3e262f01c81"
          }
        },
        "86bcccb842244f4f9add58f62facaace": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "8f20ed2b74d84e80a8d403793354adea": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "9107f7bfc8d3483390f802b0458e9380": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_92f64c7e088342b9b3c070ba7a295ed0",
            "placeholder": "​",
            "style": "IPY_MODEL_ab0aa8af3816422e9d97934f12af842c",
            "value": "Downloading data: 100%"
          }
        },
        "92f64c7e088342b9b3c070ba7a295ed0": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "952397f9c91c480184fa57e175ab1b4c": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "9b9b9d573d44464f9a6f5030a40245fe": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "ab0aa8af3816422e9d97934f12af842c": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "aecd740c17c84d45b0615d4fc4196035": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "bd5273325a4b453e8053d98a09fe9493": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "c028ed977b5e479fbd93b8add588a6dc": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "c311b777514f41ef986756a386c0bb34": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_e2e4bf053ce442f6aee6ffab5f76525f",
              "IPY_MODEL_c88cf701e20b4354a63ac7d8645d1df9",
              "IPY_MODEL_f71c252ada474be882b0335ed9a0a1c3"
            ],
            "layout": "IPY_MODEL_e059c665229e46ea905dcbd6fc179c88"
          }
        },
        "c80062a855cb41a28ac625ab03635da2": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "c88cf701e20b4354a63ac7d8645d1df9": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_57c9af47364d48ffbb4ffbdd2c951ede",
            "max": 13468,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_fa9d75fcb1d5400c8ca1d1d13d28d0c7",
            "value": 13468
          }
        },
        "d46ee1c39bac44c2b541a88c883de1cb": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "dad286d42a514c9ca6bb01bfe9e9c4be": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "e059c665229e46ea905dcbd6fc179c88": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "e113a50f8ed2410ca12ce7cb38a1681d": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "e1acc6e70b9246a5b063b3e262f01c81": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "e2e4bf053ce442f6aee6ffab5f76525f": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_bd5273325a4b453e8053d98a09fe9493",
            "placeholder": "​",
            "style": "IPY_MODEL_8f20ed2b74d84e80a8d403793354adea",
            "value": "Generating train split: 100%"
          }
        },
        "ec165fdbe87a4b00a6c288ef1e85c0a9": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HBoxModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_17859b793a304e389d1ea0b9ccc3646f",
              "IPY_MODEL_34921fd116cc42b7b530174d9f61e71e",
              "IPY_MODEL_2d5466a5e98849c5a09f16faa98f91da"
            ],
            "layout": "IPY_MODEL_952397f9c91c480184fa57e175ab1b4c"
          }
        },
        "f5c80fa70ead4c86aa3b2a046061b901": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "FloatProgressModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_ff89a891bd9c42a8be164587a94ccac1",
            "max": 3495021,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_e113a50f8ed2410ca12ce7cb38a1681d",
            "value": 3495021
          }
        },
        "f71c252ada474be882b0335ed9a0a1c3": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "HTMLModel",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_682644a713b145f0b2dcff99790c6d4d",
            "placeholder": "​",
            "style": "IPY_MODEL_9b9b9d573d44464f9a6f5030a40245fe",
            "value": " 13468/13468 [00:00&lt;00:00, 193879.37 examples/s]"
          }
        },
        "fa9d75fcb1d5400c8ca1d1d13d28d0c7": {
          "model_module": "@jupyter-widgets/controls",
          "model_module_version": "1.5.0",
          "model_name": "ProgressStyleModel",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "ff89a891bd9c42a8be164587a94ccac1": {
          "model_module": "@jupyter-widgets/base",
          "model_module_version": "1.2.0",
          "model_name": "LayoutModel",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        }
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
